# Copyright (C) 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classes for extracting profiling information from simpleperf record files.

Example:
    analyzer = RecordAnalyzer()
    analyzer.analyze('perf.data')

    for event_name, event_count in analyzer.event_counts.items():
        print(f'Number of {event_name} events: {event_count}')
"""

import collections
import logging
import sys

from typing import DefaultDict, Dict, Iterable, Iterator, Optional

# Disable import-error as simpleperf_report_lib is not in pylint's `sys.path`
# pylint: disable=import-error
import simpleperf_report_lib  # type: ignore


class Instruction:
    """Accumulates per-event sample counts for one assembly instruction.

    Attributes:
        relative_addr (int): Offset of the instruction from the start of the
            method that contains it. On arm64, for example, the first
            instruction of a method sits at offset 0, the next one at 4, etc.
        event_counts (DefaultDict[str, int]): Running total of events seen
            for this instruction, keyed by event name. Unseen event names
            read as 0.
    """

    def __init__(self, relative_addr: int) -> None:
        """Creates an Instruction with no recorded events.

        Args:
            relative_addr (int): Offset of the instruction within its method.
        """
        self.event_counts: DefaultDict[str, int] = collections.defaultdict(int)
        self.relative_addr = relative_addr

    def record_sample(self, event_name: str, event_count: int) -> None:
        """Adds the event count carried by one sample to this instruction.

        Args:
            event_name (str): Name of the sampled event.
            event_count (int): Number of events the sample accounts for.
        """
        totals = self.event_counts
        totals[event_name] = totals[event_name] + event_count


class Method:
    """Accumulates sample counts for one compiled method and its instructions.

    Attributes:
        name (str): The method name.
        event_counts (DefaultDict[str, int]): Running total of events seen
            anywhere in this method, keyed by event name.
        instructions (Dict[int, Instruction]): Per-instruction breakdown,
            keyed by the instruction's offset relative to the method start.
            An entry is created the first time an offset is sampled.
    """

    def __init__(self, name: str) -> None:
        """Creates a Method with no recorded events.

        Args:
            name (str): The method name.
        """
        self.name = name

        self.event_counts: DefaultDict[str, int] = collections.defaultdict(int)
        self.instructions: Dict[int, Instruction] = {}

    def record_sample(self, relative_addr: int, event_name: str,
                      event_count: int) -> None:
        """Adds one sample's event count to the method and to the instruction
        at ``relative_addr``.

        Args:
            relative_addr (int): Offset of the sampled instruction within the
                method.
            event_name (str): Name of the sampled event.
            event_count (int): Number of events the sample accounts for.
        """
        self.event_counts[event_name] += event_count

        try:
            hit = self.instructions[relative_addr]
        except KeyError:
            # First sample at this offset: start tracking the instruction.
            hit = Instruction(relative_addr)
            self.instructions[relative_addr] = hit

        hit.record_sample(event_name, event_count)


class RecordAnalyzer:
    """RecordAnalyzer extracts profiling information from simpleperf record
    files.

    Multiple record files can be analyzed successively, each containing one or
    more event types. Only samples from odex files are analyzed, as we're
    interested in the performance of methods generated by the optimizing
    compiler.

    Attributes:
        event_names (Set[str]): A set of event names to analyze. If empty, all
            events are analyzed.
        event_counts (DefaultDict[str, int]): A mapping of event names to their
            total number of events for the analyzed samples, accumulated over
            every record file analyzed so far.
        methods (Dict[str, Method]): A mapping of method names to their Method
            object, accumulated over every record file analyzed so far.
        report (simpleperf_report_lib.ReportLib): The ReportLib object for the
            record file most recently passed to `analyze`. It is only assigned
            by `analyze`, and it has been closed by the time `analyze` returns.
        target_arch (str): A target architecture determined from the first
            record file analyzed.
    """

    def __init__(self, event_names: Optional[Iterable[str]] = None) -> None:
        """Instantiates a RecordAnalyzer.

        Args:
            event_names (Optional[Iterable[str]]): An optional iterable of
                event names to analyze. If empty or falsy, all events are
                analyzed.
        """
        # `or ()` maps None (and any other falsy value) to "no filter".
        self.event_names = set(event_names or ())

        self.event_counts: DefaultDict[str, int] = collections.defaultdict(int)
        self.methods: Dict[str, Method] = {}
        # Declared here, assigned once per record file in `analyze`.
        self.report: simpleperf_report_lib.ReportLib
        self.target_arch = ''

    def analyze(self, filename: str) -> None:
        """Analyzes a perf record file.

        Counts from this file are added on top of those recorded by previous
        calls. The ReportLib object opened for the file is always closed before
        this method returns or raises.

        Args:
            filename (str): The path to a perf record file.

        Raises:
            SystemExit: If the record file's architecture differs from the
                architecture of the first record file analyzed.
        """
        # One ReportLib object needs to be instantiated per record file
        self.report = simpleperf_report_lib.ReportLib()
        try:
            self.report.SetRecordFile(filename)

            arch = self.report.GetArch()
            if not self.target_arch:
                self.target_arch = arch
            elif self.target_arch != arch:
                logging.error(
                    'Record file %s is for the architecture %s, expected %s',
                    filename, arch, self.target_arch)
                # SystemExit unwinds through the `finally` below, which
                # closes the report before the process exits.
                sys.exit(1)

            for sample in self.samples():
                event = self.report.GetEventOfCurrentSample()
                if self.event_names and event.name not in self.event_names:
                    continue

                symbol = self.report.GetSymbolOfCurrentSample()
                # Offset of the sampled instruction from its symbol's start.
                relative_addr = symbol.vaddr_in_file - symbol.symbol_addr
                self.record_sample(symbol.symbol_name, relative_addr,
                                   event.name, sample.period)
        finally:
            # Release the report even if an exception is raised while the
            # record file is being read, so it is never leaked.
            self.report.Close()

        logging.info('Analyzed %d event(s) for %d method(s)',
                     len(self.event_counts), len(self.methods))

    def samples(self) -> Iterator[simpleperf_report_lib.SampleStruct]:
        """Iterates over samples for compiled methods located in odex files.

        Must be called after `analyze` has set `self.report` to an open report.

        Yields:
            simpleperf_report_lib.SampleStruct: A sample for a compiled method.
        """
        sample = self.report.GetNextSample()
        while sample:
            symbol = self.report.GetSymbolOfCurrentSample()
            if symbol.dso_name.endswith('.odex'):
                yield sample

            sample = self.report.GetNextSample()

    def record_sample(self, method_name: str, relative_addr: int,
                      event_name: str, event_count: int) -> None:
        """Records profiling information given by a sample.

        Adds the count to the global total and to the per-method (and
        per-instruction) totals, creating the Method entry on first use.

        Args:
            method_name (str): A method name.
            relative_addr (int): The relative address of an instruction hit.
            event_name (str): An event name.
            event_count (int): An event count.
        """
        self.event_counts[event_name] += event_count

        if method_name not in self.methods:
            self.methods[method_name] = Method(method_name)

        method = self.methods[method_name]
        method.record_sample(relative_addr, event_name, event_count)
